############################################
# NOTE: This code only produces Figure 3.
# Also note that the results might change slightly if you do not set a seed or use a different one.
############################################

# Divert console output to the log file, appending to what is already there.
sink(file = "scripts/forest_script_log.txt", append = TRUE)

if (!require("pacman")) {
  install.packages("pacman")
  library(pacman)
}

pacman::p_load(grf, foreign, readstata13, stargazer, haven, dplyr, ggplot2, tidyverse, tidyr, broom, estimatr, WeightIt)

# Remove every visible object from the global environment before starting.
rm(list = ls())

# Seed the random number generator (see the note at the top of this file).
set.seed(1234)

# Project root -- set your own working directory here
setwd(dir = "~/Dropbox/Russia Ukraine EU/Replication")

# Load the prepared data set; the path is relative to the working directory.
dat <- read_rds(file = "data/dat.rds")

# New names for the nationality indicator columns (the columns whose names
# start with "q1_"). They are assigned positionally, so the vector must have
# one entry per "q1_" column, in the order those columns appear in `dat`.
new_names <- c(
  "BE", "DK", "DE", "GR", "ES", "FR", "IE", "IT", "LU", "NL",
  "PT", "UK", "AT", "SE", "FI", "CY", "CZ", "EE", "HU", "LV",
  "LT", "MT", "PO", "SK", "SI", "BG", "RO", "HR", "Other", "No"
)

dat <- dat %>%
  rename_at(vars(starts_with("q1_")), ~new_names)

# One nationality per respondent: the name of the indicator column holding the
# row maximum. The columns have to be selected by their NEW names -- after the
# rename above nothing starts with "q1_" any more, so selecting on that prefix
# would return zero columns and leave `nationality` meaningless.
# max.col() is left at its default tie handling (ties are broken at random).
nation <- dat %>%
  dplyr::select(all_of(new_names)) %>%
  mutate(nationality = names(.)[max.col(.)]) %>%
  pull(nationality)

# Store the indicators as factors. This runs after max.col() so the row
# maximum above is taken over the original values, not over factor levels.
dat <- dat %>%
  mutate(across(all_of(new_names), as.factor))

dat <- cbind(dat, nationality = nation)

# Paste all columns whose names start with "region" into a single `region`
# column; NA pieces are skipped and the source columns are removed.
dat <- unite(dat, region, starts_with("region"), remove = TRUE, na.rm = TRUE)

# Interview timing:
#   Date      = 2022-02-20 plus `p1` days
#   day       = number of days between 2022-02-20 and Date
#   Treatment = 1 when Date is later than 2022-02-23, otherwise 0
dat <- dat %>%
  mutate(Date = as.Date("2022-02-20") + days(p1)) %>%
  mutate(day = as.numeric(difftime(Date, as.Date("2022-02-20"), units = "days"))) %>%
  mutate(Treatment = ifelse(Date > as.Date("2022-02-23"), 1, 0))

# Build the respondent-level covariates and the outcome from the raw survey
# items. The steps are grouped by theme; columns are created in the same order
# as before, and every 0/1 indicator is 1 when the raw item equals the code
# shown and 0 otherwise.
dat <- dat %>%
  # Sex, age and education (`university` is 1 when d9a is 7, 8 or 9).
  mutate(
    female = ifelse(d10 == 2, 1, 0),
    age = d11,
    university = ifelse(d9a %in% 7:9, 1, 0)
  ) %>%
  # Occupation: the raw item as a factor plus one indicator per code 1-8.
  mutate(
    occupation = as.factor(d15a_r2),
    self_employed = ifelse(d15a_r2 == 1, 1, 0),
    manager = ifelse(d15a_r2 == 2, 1, 0),
    white_collar = ifelse(d15a_r2 == 3, 1, 0),
    manual_worker = ifelse(d15a_r2 == 4, 1, 0),
    house_person = ifelse(d15a_r2 == 5, 1, 0),
    unemployed = ifelse(d15a_r2 == 6, 1, 0),
    retired = ifelse(d15a_r2 == 7, 1, 0),
    student = ifelse(d15a_r2 == 8, 1, 0)
  ) %>%
  # Internet use is copied as is; `married` is 1 when marital status is 1.
  mutate(
    internetuse = netuse,
    marital_status = as.factor(d7r),
    married = ifelse(marital_status == 1, 1, 0)
  ) %>%
  # Type of community, one indicator per code 1-3 of d25.
  mutate(
    community_rural = ifelse(d25 == 1, 1, 0),
    community_small = ifelse(d25 == 2, 1, 0),
    community_city = ifelse(d25 == 3, 1, 0)
  ) %>%
  # Ideology with codes 11 and 12 set to NA, the non-EU indicator, and the
  # outcome: d19_2 shifted by 1 and divided by 3.
  mutate(
    ideology_num = na_if(na_if(d1, 11), 12),
    non_eu = ifelse(nationality == "Other", 1, 0),
    eu_support = (d19_2 - 1) / 3
  )

# Country identifiers. The three steps must stay in this order:
#   1. `country_code` and `cluster` (date x country) use `isocntry` as it is
#      stored, i.e. before the two German codes are merged;
#   2. "DE-E" and "DE-W" are collapsed into "DE";
#   3. the collapsed code is mapped to the country's name (codes missing from
#      the mapping are left unchanged by recode()).
dat <- dat %>%
  mutate(
    country_code = as.factor(isocntry),
    cluster = as.factor(paste(Date, country_code, sep = "_"))
  ) %>%
  mutate(
    isocntry = ifelse(isocntry %in% c("DE-E", "DE-W"), "DE", isocntry)
  ) %>%
  mutate(
    country_name = dplyr::recode(
      isocntry,
      "AT" = "Austria",
      "BE" = "Belgium",
      "BG" = "Bulgaria",
      "CY" = "Cyprus",
      "CZ" = "Czechia",
      "DE" = "Germany",
      "DK" = "Denmark",
      "EE" = "Estonia",
      "ES" = "Spain",
      "FI" = "Finland",
      "FR" = "France",
      "GR" = "Greece",
      "HR" = "Croatia",
      "HU" = "Hungary",
      "IE" = "Ireland",
      "IT" = "Italy",
      "LT" = "Lithuania",
      "LU" = "Luxembourg",
      "LV" = "Latvia",
      "MT" = "Malta",
      "NL" = "Netherlands",
      "PL" = "Poland",
      "PT" = "Portugal",
      "RO" = "Romania",
      "SE" = "Sweden",
      "SI" = "Slovenia",
      "SK" = "Slovakia"
    )
  )

# Covariates passed to the forest, each paired with the label used for it in
# the figure. Keeping name and label side by side guarantees they stay aligned.
label_lookup <- c(
  age             = "Age",
  female          = "Women",
  university      = "University",
  self_employed   = "Self-employed",
  manager         = "Manager",
  white_collar    = "White-collar",
  manual_worker   = "Manual worker",
  house_person    = "House person",
  unemployed      = "Unemployed",
  retired         = "Retired",
  student         = "Student",
  married         = "Married",
  community_rural = "Community rural",
  community_small = "Community small",
  community_city  = "Community city",
  internetuse     = "Internet use",
  non_eu          = "Non-EU national",
  ideology_num    = "Ideology"
)

# Column names of the covariates, in the order listed above.
covariate_names <- names(label_lookup)

# Two-column table mapping each covariate name to its display label.
covariate_labels <- data.frame(
  covariate = covariate_names,
  label = unname(label_lookup)
)

# The lookup was only a construction aid; do not leave it in the workspace.
rm(label_lookup)

# Names of the outcome and treatment columns in `dat`.
outcome_variable_name <- c("eu_support")
treatment_variable_name <- c("Treatment")

# Keep only outcome, treatment and covariates. all_of() stops with an error if
# any of these columns is missing from `dat`; the superseded one_of() would
# only warn and silently drop the column.
all_variables_names <- c(outcome_variable_name, treatment_variable_name, covariate_names)
df <- dat %>% dplyr::select(all_of(all_variables_names))

# Drop every row with a missing value in any of the selected columns.
df <- na.omit(df)

# Short names used by the forest code below: Y = outcome, W = treatment.
df <- df %>% dplyr::rename(Y = eu_support, W = Treatment)

# Converting all columns to numerical
df <- data.frame(lapply(df, function(x) as.numeric(as.character(x))))

# Random split of the rows of `df`: 80% (rounded down) are drawn without
# replacement for training, the remaining rows form the test set.
train_fraction <- 0.80
n <- nrow(df)
train_idx <- sample.int(n, size = floor(n * train_fraction), replace = FALSE)
df_train <- df[train_idx, ]
df_test <- df[-train_idx, ]

# Fit the causal forest on the training rows with 10,000 trees.
cf <- causal_forest(
  X = as.matrix(df_train[, covariate_names]),
  Y = df_train[["Y"]],
  W = df_train[["W"]],
  num.trees = 10000
)

# Predictions for the training rows (no new data supplied), together with
# standard errors taken as the square root of the variance estimates.
oob_pred <- predict(cf, estimate.variance = TRUE)
oob_tauhat_cf <- oob_pred[["predictions"]]
oob_tauhat_cf_se <- sqrt(oob_pred[["variance.estimates"]])

# Variable importance as a named vector, largest value first. The values are
# labelled in the order of `covariate_names`, i.e. the column order of X.
var_imp <- sort(
  setNames(c(variable_importance(cf)), covariate_names),
  decreasing = TRUE
)

# Number of groups the training rows are cut into by predicted effect.
num_tiles <- 4

# Attach the predicted effect and its n-tile group to the training data.
df_train$cate <- oob_tauhat_cf
df_train$ntile <- factor(ntile(oob_tauhat_cf, n = num_tiles))

# Group-wise effect estimates from a regression with robust standard errors:
# keep only the ntile-by-W interaction terms.
estimated_sample_ate <- dplyr::filter(
  tidy(lm_robust(Y ~ ntile + ntile:W, data = df_train)),
  stringr::str_detect(term, ":W")
)

# Forest-based average treatment effect within each group, one row per group.
estimated_aipw_ate <- bind_rows(
  lapply(seq(num_tiles), function(w) {
    average_treatment_effect(
      cf,
      subset = df_train$ntile == w,
      target.sample = "control"
    )
  })
)

# Stack both sets of estimates in a common layout. The forest-based rows reuse
# the regression term names and get 1.96-standard-error confidence bounds.
combined_estimates <- bind_rows(
  dplyr::select(
    mutate(estimated_sample_ate, type = "lm_robust"),
    -outcome, -df, -statistic, -p.value
  ),
  estimated_aipw_ate %>%
    dplyr::rename(std.error = std.err) %>%
    mutate(
      type = "aipw",
      term = estimated_sample_ate$term,
      conf.low = estimate - 1.96 * std.error,
      conf.high = estimate + 1.96 * std.error
    )
)

# Bundle the fitted forest and its by-products. The bundle is bound to `test`
# because that is the object fitted_vals() uses as its default `model`;
# previously the list was built but never stored, so that default pointed at
# an object that did not exist.
# The outer parentheses keep the value visible, so it is still auto-printed at
# top level exactly as the unassigned list was.
(test <- list(cf = cf,
              df_train = df_train,
              X = as.matrix(df_train[, covariate_names]),
              oob_tauhat_cf = oob_tauhat_cf,
              var_imp = var_imp,
              ntile_estimates = combined_estimates))

# Predict treatment effects along a grid of values of one covariate, holding
# every other covariate at its median in the training data.
#
# Arguments:
#   var_of_interest  Name (character string) of the covariate to vary. It is
#                    expected to be one of the global `covariate_names`.
#   model            List with elements `df_train` (training data) and `cf`
#                    (fitted forest). Defaults to the global object `test`.
#
# Returns the evaluation grid sorted by the covariate of interest, with the
# covariate columns, `tauhat` (prediction), `se` (square root of the variance
# estimate) and a factor column literally named `var_of_interest` holding the
# grid values.
fitted_vals <- function(var_of_interest, model = test){

  df_train <- model$df_train
  cf <- model$cf

  x_values <- df_train[[var_of_interest]]

  # More than 5 distinct values: treat as continuous and evaluate at the
  # minimum, quartiles and maximum. Otherwise evaluate at every distinct value.
  is_continuous <- length(unique(x_values)) > 5
  if (is_continuous) {
    x_grid <- quantile(x_values, probs = seq(0, 1, length.out = 5))
  } else {
    x_grid <- sort(unique(x_values))
  }

  df_grid <- setNames(data.frame(x_grid), var_of_interest)

  # One row holding the median of every other covariate, crossed with the grid.
  other_covariates <- covariate_names[!covariate_names %in% var_of_interest]
  df_median <- df_train %>%
    dplyr::select(all_of(other_covariates)) %>%
    summarise(across(everything(), median))
  df_eval <- crossing(df_median, df_grid)

  # Columns are put back into `covariate_names` order before predicting.
  pred <- predict(cf, newdata = df_eval[, covariate_names], estimate.variance = TRUE)
  df_eval$tauhat <- pred$predictions
  df_eval$se <- sqrt(pred$variance.estimates)

  # Sort by the covariate named in `var_of_interest`. The earlier
  # arrange(var_of_interest) sorted by that constant string and so did nothing.
  # The factor is derived after sorting so it lines up with the rows.
  df_eval <- df_eval[order(df_eval[[var_of_interest]]), ]
  df_eval$var_of_interest <- as.factor(as.numeric(df_eval[[var_of_interest]]))
  df_eval
}

# Variable importance as a two-column table (covariate name, importance value).
hat_matters <- data.frame(covariate = names(var_imp), value = var_imp)

# Add the display label, assign each covariate to a thematic category, and
# order the labels by importance (ascending) for plotting.
hat_matters_labels <- hat_matters %>%
  left_join(covariate_labels, by = "covariate") %>%
  mutate(
    category = case_when(
      covariate == "university" ~ "Education",
      covariate == "ideology_num" ~ "Political orientation",
      covariate %in% c("internetuse") ~ "Media use",
      covariate %in% c("age", "female", "married", "non_eu") ~ "Demographics",
      covariate %in% c("community_small", "community_city", "community_rural") ~ "Community type",
      covariate %in% c("manual_worker", "unemployed", "retired", "white_collar",
                       "manager", "student", "self_employed", "house_person") ~ "Employment",
      TRUE ~ "Other"
    ),
    label = fct_reorder(label, value, .desc = FALSE)
  )

# Figure 3: variable importance of each covariate, one point per covariate,
# with the point shape (and the fill aesthetic) mapped to the category.
#
# The y axis needs no extra scale: `label` is a factor already ordered by
# importance. The earlier scale_y_reordered() call was dropped because it
# belongs to a package this script never loads, and it only strips "___suffix"
# markers that these labels do not contain, so the figure is unchanged.
p_f3 <- ggplot(hat_matters_labels, aes(value, label, fill = category, shape = category)) +
  geom_point() +
  scale_x_continuous(name = "Variable Importance") +
  ylab("Covariate") +
  theme_bw() +
  theme(strip.text.y = element_text(angle = 0, hjust = 1)) +
  guides(fill = guide_legend(title = "Category"),
         shape = guide_legend(title = "Category"))

# Show the plot on the active graphics device, then write it to disk.
p_f3
ggsave(plot = p_f3,
       filename = "out/hte_plot.pdf",
       width = 15,
       height = 12,
       units = "in")

# Stop diverting console output to the log file.
sink()
